// TODO: NOTE: I am not convinced that all of the derivative stuff still works, now that all data statistics and preprocessing have been removed!!!
// TODO: NOTE: ... this must now be handled by stats++ preprocessor.

// TODO: should rely on the external activation functions (stats++) ... instead of redefining everything here
// TODO: ... with the above, should not have to constantly do a case check since the activation function can become a general pointer
// TODO: Consistency with all of the machine learning packages should be made --- using a single set of activation functions, etc.

// TODO: NOTE: The evolutionary algorithm training was removed ... and thinking about it, perhaps that functionality can/should come from the outside.

// TODO: The DataSet really needs to be rethought --- and it needs to be made consistent between machine learning algorithms. I really don't know if we want to just stick with Matrix<>, or if we could use the extra power/flexibility of a DataSet.

// TODO: Need to decide 110% on the name of NeuralNet, and then update the case.

// TODO: For training (in general), need to come up with a better method than storing all weights at every epoch --- does not seem like a practical, scalable method

// TODO: I do not know if the recurrent part of these networks work.

// TODO: I think that a structure with training parameters would simplify things. This idea extends also to RBM/DBN.

// TODO: Because this is a graphical-model implementation, should make it so a network can be initialized by specifying architecture with linkings, etc.

// TODO: Need to verify that fully-connected (maybe will go away) and recurrent networks still work.

// TODO: Should implement some "early stopping" techniques (see the paper "Early Stopping: But When?").
//
// TODO: NOTE: ... Perhaps there could be similar (or the same) techniques for overtraining (this is important for ensembles, for example).

#ifndef STATSxx_MACHINE_LEARNING_NEURALNET_HPP
#define STATSxx_MACHINE_LEARNING_NEURALNET_HPP


// STL
#include <tuple>                                         // std::tuple<>
#include <vector>

// Boost
#include <boost/serialization/serialization.hpp>         // boost::serialization::
#include <boost/serialization/vector.hpp>                // serialize std::vector<>

// jScience
#include "jScience/linalg.hpp"                           // Matrix<>, Vector<>

// stats++
#include "statsxx/dataset.hpp"                           // DataSet
#include "statsxx/machine_learning/Learner.hpp"          // Learner
#include "statsxx/machine_learning/NeuralNet/Link.hpp"   // LINK
#include "statsxx/machine_learning/NeuralNet/Neuron.hpp" // NEURON


//=========================================================
// NEURAL NETWORK
//=========================================================
//
// NOTE: This is derived from a Learner object, so that it works with the Ensemble class.
//
class NEURAL_NET : public Learner
{

    // TODO: Should get rid of the m_ prefixes for all variable names.

    // TODO: Subroutines should take constant parameters, not just by reference.

public:

    // TODO: The public check for this should probably go to a utility subroutine.
    bool m_isClassif;  // WHETHER THE NETWORK IS FOR CLASSIFICATION


    NEURAL_NET();
    ~NEURAL_NET();

    // ----- INITIALIZATION -----

    // Puts the network into a known initial/empty state (definition in NeuralNet.cpp).
    // NOTE(review): exact reset semantics not visible from this header --- confirm in NeuralNet.cpp.
    //
    // TODO: The init() subroutine probably isn't being used as appropriately as it should.
    // TODO: ... In fact, it probably doesn't even need to be used --- could put into the constructor.
    void   init();

    // TODO: Maybe only have one create_MLP() subroutine that takes the architecture vector.

    // Builds a multilayer perceptron.
    //   ni            : number of input neurons
    //   no            : number of output neurons
    //   nl            : number of hidden layers
    //   nhn           : hidden-neuron counts (presumably one entry per hidden layer, i.e. size nl --- TODO confirm)
    //   fully_connect : whether to add the extra "fully connect" links (see TODO at top: may go away)
    //   recurrent     : whether to add recurrent links
    //   af_type       : activation function selector (see inline comment)
    //   isClass       : classification flag (presumably stored in m_isClassif --- confirm in create_MLP.cpp)
    //   w             : optional flat initial-weight vector; if empty, weights are presumably
    //                   initialized internally (e.g. randomly) --- TODO confirm in create_MLP.cpp
    void   create_MLP(
                      int ni,
                      int no,
                      int nl,
                      std::vector<int> nhn,
                      bool fully_connect,
                      bool recurrent,
                      int  af_type, // 0 == logistic, 1 == tanh, 2 == softplus
                      bool isClass,
                      const std::vector<double> &w = std::vector<double>()
                      );

    // Builds an MLP from a full architecture specification plus explicit weights.
    //   architecture : layer sizes (presumably input, hidden..., output --- TODO confirm ordering)
    //   W, bias      : initial link weights and bias values
    void   create_MLP(
                      std::vector<int> architecture,
                      bool recurrent,
                      int  af_type, // 0 == logistic, 1 == tanh, 2 == softplus
                      bool isClass,
                      const Matrix<double> &W,
                      const Vector<double> &bias
                      );


    // ----- TRAINING -----

    // Trains the network; returns a double (presumably a final error/performance measure ---
    // confirm in train.cpp). NOTE(review): `method` appears to select among the private
    // training routines (backpropagation / QRprop / CG / SCG) --- confirm the integer mapping
    // in train.cpp. Parameter groups:
    //   nepoch_min/max     : epoch bounds
    //   early_stopping     : enable early stopping (see "Early Stopping: But When?" TODO above)
    //   lr, lr_min, lr_max : learning rate and its bounds
    //   momentum           : momentum term (backpropagation)
    //   weight_penalty     : weight-decay/regularization strength
    //   qrprop_u/d         : QRprop step-increase/decrease factors (presumably --- see QRprop())
    //   scg_*              : parameters forwarded to SCG()
    //   TS_tr/val/gen      : training / validation / generalization data sets
    //   npts_per_batch     : points per (mini-)batch
    //   silent             : suppress progress output (presumably --- confirm in train.cpp)
    double train(
                 const int      method,
                 //------
                 const int      nepoch_min,
                 const int      nepoch_max,
                 // -----
                 const bool     early_stopping,
                 // -----
                 const double   lr,
                 const double   lr_min,
                 const double   lr_max,
                 // -----
                 const double   momentum,
                 // -----
                 const double   weight_penalty,
                 // -----
                 const double   qrprop_u,
                 const double   qrprop_d,
                 // -----
                 const double   scg_lambda,
                 const double   scg_sigma,
                 const double   scg_convg_iterfrac,
                 const double   scg_rk_tol,
                 // -----
                 const DataSet &TS_tr,
                 const DataSet &TS_val,
                 const DataSet &TS_gen,
                 // -----
                 const int      npts_per_batch,
                 // -----
                 const bool     silent
                 );

    // ----- EVALUATION -----

    // TODO: In both of the following, I am not sure that both a single and wrapper subroutine is needed (clutters the basic class up).

    // Forward-propagates `input` through the network and writes the result into `output`.
    // NOTE(review): input is a vector of vectors, presumably to support recurrent sequences
    // (one time step per inner vector) --- confirm in evaluate.cpp.
    void   evaluate(const std::vector<std::vector<double>> &input, std::vector<double> &output);

    // Evaluates the network over a whole DataSet; the meaning of the two doubles and the
    // per-point outputs in the returned tuple is defined in parse_TS.cpp --- TODO confirm
    // (likely error measures plus the network outputs).
    std::tuple<double, double, std::vector<std::vector<double>>> parse_TS(DataSet ds);

    // Derivative of a single output neuron w.r.t. a single input (see TODO at top of file:
    // derivative handling may be stale now that preprocessing was removed --- verify).
    double deriv(const int neuron_idx, const int inp_idx);
    // Derivatives of all outputs w.r.t. all inputs for a given input point.
    std::vector<std::vector<double>> derivs(std::vector<double> input);

    // ----- UTILITY -----

    // Returns the link weights as a flat vector (ordering defined in get_weights.cpp).
    std::vector<double> get_weights();

    // TODO: This should probably go to a utility subroutine. This would clear up this class, and could be called from the outside.
    // TODO: ... If this is the case, calling NeuralNet.hpp at the top should load up this class, utility subroutines, etc.
    //
    // TODO: NOTE: Maybe a subroutine should be provided to optimize only a single weight.
    // TODO: NOTE: ... I can envision several uses for this --- right now just EA training uses it.
    void   optimize_link_weights();

    // Sets the link weights from a flat vector (inverse of get_weights(); same ordering).
    void   assign_weights(const std::vector<double> &w);

private:

    int                              m_ninp, m_nout;  // number of input and output

    // << JMM >> :: I could probably change the m_bias_neurons vector to just a value, since bias neurons should ALWAYS have a value of 1 -- but I will leave the vector for now, in case there is something I have not foreseen
    std::vector<NEURON>              m_neurons;       // all neurons in the graph
    std::vector<int>                 m_inp_neurons;   // indices (presumably into m_neurons) of input neurons
    std::vector<int>                 m_bias_neurons;  // indices of bias neurons
    std::vector<int>                 m_out_neurons;   // indices of output neurons

    std::vector<LINK>                m_links;         // all links in the graph
    std::vector<std::vector<int>>    m_links_in;      // per-neuron incoming link indices (presumably)
    std::vector<std::vector<int>>    m_links_out;     // per-neuron outgoing link indices (presumably)

    // Lookup tables built by create_lookup_tables() / get_layer_neuron_idx(); exact
    // semantics live in those .cpp files.
    std::vector<bool>                m_lkupTbl_neurPtToNeur_for;
    std::vector<bool>                m_lkupTbl_neurPtToNeur_back;
    std::vector<std::vector<int>>    m_lkupTbl_connections;
    std::vector<std::vector<int>>    layer_neuron_idx;  // NOTE(review): only member without the m_ prefix (see TODO above)

    // ----- INITIALIZATION -----

    void   create_lookup_tables();
    void   get_layer_neuron_idx();

    // Graph queries used while building the lookup tables; `_forward` selects the
    // propagation direction (forward vs. backward).
    bool   neurons_pt_to_neurons(const std::vector<int> &neurons, bool _forward);
    bool   check_for_connection(NEURON &n1, NEURON &n2);
    void   find_source_neurons(std::vector<int> &neurons, bool _forward);

    // ----- TRAINING -----

    // Gradient-descent training with momentum and weight decay (mini-batches of
    // npts_per_batch points).
    void   backpropagation(
                           const int      max_epochs,
                           // -----
                           const bool     early_stopping,
                           // -----
                           const double   lr,
                           // -----
                           const double   momentum,
                           // -----
                           const double   weight_penalty,
                           // -----
                           const DataSet &ds_tr,
                           const DataSet &ds_val,
                           // -----
                           const int      npts_per_batch
                           );

    // QRprop training; gamma_* bound the per-weight step, u/d are the increase/decrease
    // factors (presumably --- confirm in QRprop.cpp).
    void   QRprop(
                  const int      max_epoch,
                  // -----
                  const bool     early_stopping,
                  // -----
                  const double   gamma_ini,
                  const double   gamma_min,
                  const double   gamma_max,
                  // -----
                  const double   u,
                  const double   d,
                  // -----
                  const DataSet &ds_tr,
                  const DataSet &ds_val
                  );

    // Conjugate-gradient training.
    // TODO: I do not know if it is worth leaving in CG.
    // TODO: ... If so, it NEEDS updating.
    void   CG(
              const int      max_epochs,
              // -----
              const bool     early_stopping,
              // -----
              const DataSet &ds_tr,
              const DataSet &ds_val
              );

    // Scaled conjugate gradient training.
    void   SCG(
               const int      MINITER,
               const int      MAXITER,
               // -----
               const bool     early_stopping,
               // -----
               double         lambda_k,
               const double   sigma,
               // -----
               const DataSet &TS_tr,
               const DataSet &TS_val,
               // -----
               const double   CONVG_ITERFRAC,
               const double   rkTOL,
               // -----
               const bool     silent
               );

    // Computes the error gradient w.r.t. the weights over data set TS, returned in dEdw.
    // NOTE(review): declared `calulate_Ederiv` (misspelled) while the included definition
    // file is calculate_Ederiv.cpp --- verify the two names actually match; renaming here
    // requires renaming the definition as well.
    // NOTE(review): TS is taken by value (copied) --- likely should be const& like the
    // other routines (see TODO about constant parameters above).
    void   calulate_Ederiv(DataSet TS, std::vector<double> &dEdw);

    // ----- EVALUATION -----

    // Copies the activations of the output neurons into `output` (presumably).
    void   assign_output(std::vector<double> &output);

    // ----- PROPAGATION -----

    // forward propagation
    void   load_input(std::vector<double> input);
    void   forward_prop(std::vector<double> &input);

    // backward propagation
    void   load_output_errors(const std::vector<double> &actual_out, const std::vector<double> &calc_out);
    void   assign_recurrent_errors(std::vector<int> &top_neurons);
    void   backward_prop(int i, const std::vector<double> &actual_out, const std::vector<double> &calc_out);

    // propagation
    void   propagate_signal(NEURON &n, std::vector<int> &ID_hit_list, bool _forward);
    void   activate_neurons(std::vector<int> &neurons, bool _forward);

    // ----- UTILITY -----

    void   clear_memory();
    void   clear_errors();
    void   deactivate_network();

    // ----- (Boost) SERIALIZATION -----
    friend class boost::serialization::access;

    // Intrusive Boost.Serialization hook: archives the full network state.
    // NOTE(review): field order here defines the archive format --- do not reorder
    // without versioning. `version` is currently unused.
    template<class Archive>
    void serialize(
                   Archive            &ar,
                   const unsigned int  version
                   )
    {
        // NOTE(review): base-class (Learner) state is deliberately not serialized ---
        // confirm Learner is stateless, or restore this line:
//        ar & boost::serialization::base_object<Learner>(*this);

        ar & this->m_isClassif;

        ar & this->m_ninp;
        ar & this->m_nout;

        ar & this->m_neurons;
        ar & this->m_inp_neurons;
        ar & this->m_bias_neurons;
        ar & this->m_out_neurons;

        ar & this->m_links;
        ar & this->m_links_in;
        ar & this->m_links_out;

        ar & this->m_lkupTbl_neurPtToNeur_for;
        ar & this->m_lkupTbl_neurPtToNeur_back;
        ar & this->m_lkupTbl_connections;
        ar & this->layer_neuron_idx;
    }

};

// ----- INITIALIZATION -----
#include "statsxx/machine_learning/NeuralNet/NeuralNet.cpp"

#include "statsxx/machine_learning/NeuralNet/create_MLP.cpp"

#include "statsxx/machine_learning/NeuralNet/create_lookup_tables.cpp"
#include "statsxx/machine_learning/NeuralNet/get_layer_neuron_idx.cpp"

#include "statsxx/machine_learning/NeuralNet/neurons_pt_to_neurons.cpp"
#include "statsxx/machine_learning/NeuralNet/check_for_connection.cpp"
#include "statsxx/machine_learning/NeuralNet/find_source_neurons.cpp"

// ----- TRAINING -----
#include "statsxx/machine_learning/NeuralNet/train.cpp"

#include "statsxx/machine_learning/NeuralNet/backpropagation.cpp"
#include "statsxx/machine_learning/NeuralNet/QRprop.cpp"
#include "statsxx/machine_learning/NeuralNet/CG.cpp"
#include "statsxx/machine_learning/NeuralNet/SCG.cpp"

#include "statsxx/machine_learning/NeuralNet/calculate_Ederiv.cpp"

// ----- EVALUATION -----
#include "statsxx/machine_learning/NeuralNet/evaluate.cpp"
#include "statsxx/machine_learning/NeuralNet/parse_TS.cpp"

#include "statsxx/machine_learning/NeuralNet/deriv.cpp"
#include "statsxx/machine_learning/NeuralNet/derivs.cpp"

#include "statsxx/machine_learning/NeuralNet/assign_output.cpp"

// ----- PROPAGATION -----
#include "statsxx/machine_learning/NeuralNet/load_input.cpp"
#include "statsxx/machine_learning/NeuralNet/forward_prop.cpp"

#include "statsxx/machine_learning/NeuralNet/load_output_errors.cpp"
#include "statsxx/machine_learning/NeuralNet/assign_recurrent_errors.cpp"
#include "statsxx/machine_learning/NeuralNet/backward_prop.cpp"

#include "statsxx/machine_learning/NeuralNet/propagate_signal.cpp"
#include "statsxx/machine_learning/NeuralNet/activate_neurons.cpp"

// ----- UTILITY -----
#include "statsxx/machine_learning/NeuralNet/get_weights.cpp"
#include "statsxx/machine_learning/NeuralNet/assign_weights.cpp"


#endif
